import torch
from torch.utils.data import TensorDataset, DataLoader

# Number of samples and mini-batch size, named so the rest of the script
# does not hard-code them.
NUM_POINTS = 200
BATCH_SIZE = 20

# Evenly spaced 1-D inputs on [-7, 7] and their cosine targets.
a = torch.linspace(-7, 7, NUM_POINTS)
b = torch.cos(a)

# A custom Dataset subclass (e.g. `MyDataSet(a, b)`) could be used instead of
# the pairings below; its size would be queried with `len(dataset)`.

# Why give `a` an extra dimension? Layers such as nn.Linear treat the last
# dimension as the feature dimension, so a column of shape (N, 1) represents
# N samples with one feature each, whereas a 1-D tensor of shape (N,) does not.
print(a.shape)

# Add a trailing feature dimension: (NUM_POINTS,) -> (NUM_POINTS, 1).
A = a.unsqueeze(1)
print(A.shape)

# Equivalent reshape via view; -1 infers the length instead of hard-coding it.
print(a.view(-1, 1).shape)

# Pairing method 1: zip the column inputs with the 1-D targets. Each item is
# (tensor of shape (1,), 0-d tensor), i.e. the target lacks the feature dim.
c = list(zip(A, b))
print(c)

# Pairing method 2: TensorDataset indexes A and B together along dim 0. Both
# are columns, so each item is (tensor of shape (1,), tensor of shape (1,)).
B = b.unsqueeze(1)
dataset = TensorDataset(A, B)
d = list(dataset)
print(d)

# Mini-batching. Batch the TensorDataset so every (x, y) batch has matching
# shapes (BATCH_SIZE, 1). Batching `c` instead would yield y of shape
# (BATCH_SIZE,), which silently broadcasts against (BATCH_SIZE, 1) predictions
# in element-wise losses such as MSELoss.
train = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
print(len(train))  # ceil(NUM_POINTS / BATCH_SIZE) batches per epoch
